# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from abc import ABC, abstractmethod

import numpy as np
import paddle
from paddlenlp_ops import (
    draft_model_postprocess,
    draft_model_preprocess,
    eagle_get_base_model_hidden_states,
    eagle_get_self_hidden_states,
)

from paddlenlp.transformers import AutoConfig, AutoInferenceModelForCausalLM
from paddlenlp.trl import llm_utils


class Proposer(ABC):
    """
    Interface shared by every draft-token proposer in the speculative decoding framework.

    A concrete proposer overrides three hooks:
    ``run`` (produce draft tokens for the current step),
    ``insert_query`` (register newly arrived queries) and
    ``postprocess`` (clean up once queries have finished).
    """

    def __init__(self, **kwargs):
        """Accept, and ignore, arbitrary keyword arguments."""

    @abstractmethod
    def run(self, model_inputs: dict[str, paddle.Tensor], **kargs):
        """
        Produce the draft tokens for the current decoding step.

        The base implementation always raises ``NotImplementedError``.
        """
        raise NotImplementedError()

    @abstractmethod
    def insert_query(self, **kwargs):
        """
        Register newly arrived queries with the proposer.
        """

    @abstractmethod
    def postprocess(self, **kargs):
        """
        Clean up proposer state after queries have finished.
        """


class InferenceWithReferenceProposer(Proposer):
    """
    Proposer implementing Inference-with-Reference (https://arxiv.org/pdf/2304.04487).

    Draft tokens are obtained by matching n-grams between the input and the output
    tokens; the matching itself is done by the ``ngram_match`` custom op.
    """

    def __init__(self, max_draft_token_num: int, max_ngram_size: int, max_batch_size: int, max_seq_len: int, **kwargs):
        """
        Args:
        max_draft_token_num (int):
            Upper bound on the number of draft tokens proposed in one step
            (k in the paper).
        max_ngram_size (int):
            Largest window size used when matching input against output
            (n in the paper).
        max_batch_size (int):
            Largest batch the proposer will be asked to handle.
        max_seq_len (int):
            Largest sequence length the proposer will be asked to handle.
        """
        super().__init__()
        self.max_batch_size = max_batch_size
        self.max_ngram_size = max_ngram_size
        self.max_draft_token_num = max_draft_token_num
        # Host-side int64 buffers handed to ngram_match. Nothing in this class fills
        # them after construction; presumably the caller writes the prompt ids and
        # lengths into them — TODO confirm.
        self.input_ids_len = paddle.zeros(shape=[max_batch_size, 1], dtype="int64").cpu()
        self.input_ids_cpu = paddle.zeros(shape=[max_batch_size, max_seq_len], dtype="int64").cpu()

    def run(self, model_inputs: dict[str, paddle.Tensor], **kargs):
        """
        Fill ``model_inputs["draft_tokens"]`` using n-gram matching.

        Args:
            model_inputs: Base-model input tensors. Reads "draft_tokens",
                "seq_lens_encoder", "seq_lens_decoder", "pre_ids", "step_idx",
                "actual_draft_token_num" and "max_length".
            **kargs: Must provide "seq_lens_this_time" (a tensor that is written back)
                and "real_batch_size".
        """
        # Host copies of the tensors that are handed to the op and copied back afterwards.
        draft_tokens_host = model_inputs["draft_tokens"].cpu()
        this_time_host = kargs["seq_lens_this_time"].cpu()
        encoder_host = model_inputs["seq_lens_encoder"].cpu()
        decoder_host = model_inputs["seq_lens_decoder"].cpu()

        from paddlenlp_ops import ngram_match

        ngram_match(
            self.input_ids_cpu,
            self.input_ids_len.cpu(),
            model_inputs["pre_ids"].cpu(),
            model_inputs["step_idx"].cpu(),
            model_inputs["actual_draft_token_num"].cpu(),
            draft_tokens_host,
            this_time_host,
            encoder_host,
            decoder_host,
            model_inputs["max_length"].cpu(),
            kargs["real_batch_size"],
            self.max_ngram_size,
            self.max_draft_token_num,
        )

        # ngram_match presumably updates these host copies in place; push them back
        # into the original (device) tensors.
        write_backs = (
            (model_inputs["draft_tokens"], draft_tokens_host),
            (model_inputs["seq_lens_encoder"], encoder_host),
            (kargs["seq_lens_this_time"], this_time_host),
        )
        for target, updated in write_backs:
            target[:] = updated.cuda()

    def insert_query(self, **kwargs):
        """
        No per-query setup is performed by this proposer.
        """

    def postprocess(self, **kwargs):
        """
        No per-query cleanup is performed by this proposer.
        """


class ModelProposer(Proposer):
    """
    Base class for proposers that run a separate draft model to produce draft tokens.

    Each call to ``run`` performs three stages:
    ``run_preprocess`` (the ``draft_model_preprocess`` op), ``run_infer``
    (implemented by subclasses) and ``run_postprocess`` (the
    ``draft_model_postprocess`` op).
    """

    # Values of ``args.speculate_method`` accepted by this proposer family.
    SUPPORTED_DRAFT_TYPES = ("draft_model", "eagle", "mtp")

    def __init__(self, args, **kwargs):
        """
        Args:
            args: Predictor arguments of the base model. A deep copy adapted for the
                draft model is produced by ``build_args``.
            **kwargs: Optional ``model_args``, forwarded to the draft model's
                ``from_pretrained``.

        Raises:
            AssertionError: If ``args.speculate_method`` is not one of
                ``SUPPORTED_DRAFT_TYPES``.
        """
        super().__init__()
        self.args = self.build_args(args)
        self.model_args = kwargs.get("model_args", None)
        self.draft_type = args.speculate_method
        self.dtype = self.args.dtype
        # Kept as an assert (AssertionError) for backward compatibility; the message now
        # lists every accepted value (it previously omitted "mtp").
        assert (
            self.draft_type in self.SUPPORTED_DRAFT_TYPES
        ), f"draft_type support [{', '.join(self.SUPPORTED_DRAFT_TYPES)}], but get {self.draft_type}"

        self.max_draft_tokens = self.args.speculate_max_draft_token_num
        # Starts equal to max_draft_tokens; nothing in this class changes it.
        self.actual_draft_token_num = self.max_draft_tokens
        self.batch_size = self.args.batch_size
        self.init_predictor()

    def build_args(self, args):
        """
        Return a deep copy of ``args`` configured for the draft model.

        The copy takes the draft model's path and quantization type, uses the
        "draft_model_sample" decode strategy in "dynamic" mode, and sets
        ``return_full_hidden_states`` to 0. ``args`` itself is not modified.
        """
        from copy import deepcopy

        draft_model_args = deepcopy(args)
        draft_model_args.quant_type = args.draft_model_quant_type
        draft_model_args.model_name_or_path = args.draft_model_name_or_path
        draft_model_args.decode_strategy = "draft_model_sample"
        draft_model_args.mode = "dynamic"
        draft_model_args.return_full_hidden_states = 0
        return draft_model_args

    def init_predictor(self):
        """
        Load the draft model and allocate its KV cache and auxiliary buffers.
        """

        tensor_parallel_rank, tensor_parallel_degree = llm_utils.init_dist_env()

        # build_args set model_name_or_path to the draft model path, so config and
        # weights both come from the draft model.
        self.config = AutoConfig.from_pretrained(self.args.draft_model_name_or_path)
        self.model = AutoInferenceModelForCausalLM.from_pretrained(
            self.args.model_name_or_path,
            config=self.config,
            predictor_args=self.args,
            model_args=self.model_args,
            dtype=self.args.dtype,
            tensor_parallel_degree=tensor_parallel_degree,
            tensor_parallel_rank=tensor_parallel_rank,
            spec_model_type=self.draft_type,
        )

        # prepare model_inputs
        self.model_inputs = {}

        self.cache_kvs_shape = self.model.get_cache_kvs_shape(self.model.config, self.args.batch_size)
        # The cache is stored as uint8 when int8 KV-cache quantization is configured.
        cachekv_dtype = self.dtype if self.config.cachekv_int8_type is None else "uint8"
        self.cache_kvs = [paddle.zeros(shape, dtype=cachekv_dtype) for shape in self.cache_kvs_shape]

        # Assumes dim 0 of each cache shape is the number of KV-cache blocks — TODO confirm.
        self.max_block_nums = self.cache_kvs_shape[0][0]
        self.free_list = list(range(self.max_block_nums))
        # (batch_size, total_max_length) int64 buffer initialised to -1.
        self.pre_ids = paddle.to_tensor(np.zeros((self.batch_size, self.args.total_max_length)).astype("int64") - 1)
        self.rope_theta = self.config.get("rope_theta", 10000.0)
        self.rope_scaling = self.config.get("rope_scaling", None)

        # Assumes the last dim of the cache shape is the attention head dim — TODO confirm.
        self.head_dim = self.cache_kvs_shape[0][-1]
        self.rope_emb = llm_utils.get_rotary_position_embedding(
            paddle.arange(self.args.total_max_length).reshape((1, -1)),
            self.head_dim,
            self.rope_theta,
            self.rope_scaling,
        )

    def run(self, share_inputs, **kwargs):
        """
        Run one proposal step: preprocess, subclass inference, postprocess.

        Args:
            share_inputs (dict): The base model's shared input tensors.
            **kwargs: Forwarded to ``run_infer``.
        """
        self.run_preprocess(share_inputs)
        self.run_infer(share_inputs, **kwargs)
        self.run_postprocess(share_inputs)

    def insert_query(self, **kwargs):
        """
        Initialize the draft model's inputs for a new batch of requests.

        Keyword Args:
            real_bs (int): Number of requests to set up.
            seq_lens: Per-request prompt lengths, indexable by request index.
            base_model_inputs (dict): The base model's input tensors. Some entries
                are cloned; others are shared by reference.

        Raises:
            ValueError: If a request's prompt length plus ``max_length`` exceeds
                ``total_max_length``, or if the KV-cache block pool has too few free
                blocks for a request. In both cases the block pool is reset first.
        """
        real_bs = kwargs.get("real_bs")
        seq_lens = kwargs.get("seq_lens")
        base_model_inputs = kwargs.get("base_model_inputs")

        max_sec_len = self.args.total_max_length
        self.model_inputs["block_tables"] = paddle.full_like(
            base_model_inputs["block_tables"], fill_value=-1, dtype="int32"
        )
        for i in range(real_bs):
            real_len = seq_lens[i] + self.args.max_length
            if real_len > max_sec_len:
                self.free_list = list(range(self.max_block_nums))
                raise ValueError(
                    f"input_len({seq_lens[i]}) + max_dec_len({self.args.max_length}) > max_seq_len({max_sec_len})"
                )
            # Ceil-divide to get the number of KV-cache blocks this request needs.
            num_blocks = (real_len + self.args.block_size - 1) // self.args.block_size
            if num_blocks > len(self.free_list):
                # Without this check free_list.pop() fails with an uninformative IndexError.
                available = len(self.free_list)
                self.free_list = list(range(self.max_block_nums))
                raise ValueError(
                    f"not enough free KV cache blocks for request {i}: "
                    f"need {num_blocks}, but only {available} available"
                )
            for j in range(num_blocks):
                self.model_inputs["block_tables"][i, j] = self.free_list.pop()

        # Cloned so the draft model can update them without touching the base model's tensors.
        self.model_inputs["input_ids"] = paddle.clone(base_model_inputs["input_ids"])
        self.model_inputs["seq_lens_this_time"] = paddle.clone(base_model_inputs["seq_lens_this_time"])
        self.model_inputs["seq_lens_encoder"] = paddle.clone(base_model_inputs["seq_lens_encoder"])
        self.model_inputs["seq_lens_decoder"] = paddle.clone(base_model_inputs["seq_lens_decoder"])
        self.model_inputs["step_idx"] = paddle.clone(base_model_inputs["step_idx"])
        self.model_inputs["stop_flags"] = paddle.clone(base_model_inputs["stop_flags"])
        self.model_inputs["stop_nums"] = paddle.clone(base_model_inputs["stop_nums"])
        self.model_inputs["not_need_stop"] = paddle.to_tensor([False], dtype="bool", place="cpu")
        self.model_inputs["pre_ids"] = self.pre_ids
        self.model_inputs["rope_emb"] = self.rope_emb
        self.model_inputs["cache_kvs"] = self.cache_kvs
        # Sampling / stopping parameters are shared by reference with the base model.
        self.model_inputs["top_p"] = base_model_inputs["top_p"]
        self.model_inputs["temperature"] = base_model_inputs["temperature"]
        self.model_inputs["eos_token_id"] = base_model_inputs["eos_token_id"]
        self.model_inputs["penalty_score"] = base_model_inputs["penalty_score"]
        self.model_inputs["frequency_score"] = base_model_inputs["frequency_score"]
        self.model_inputs["presence_score"] = base_model_inputs["presence_score"]
        self.model_inputs["max_length"] = base_model_inputs["max_length"]
        self.model_inputs["min_length"] = base_model_inputs["min_length"]
        self.model_inputs["bad_tokens"] = base_model_inputs["bad_tokens"]
        self.model_inputs["next_tokens"] = paddle.full(shape=[self.batch_size, 1], fill_value=-1, dtype="int64")
        self.model_inputs["base_model_draft_tokens"] = base_model_inputs["draft_tokens"]
        self.model_inputs["draft_tokens"] = paddle.full(shape=[self.batch_size, 2], fill_value=-1, dtype="int64")

        # Records each request's prompt length; passed to draft_model_preprocess.
        self.first_token_record = paddle.full(shape=[self.batch_size, 1], fill_value=-1, dtype="int32")
        self.model_inputs["substep"] = 0
        for i in range(real_bs):
            self.model_inputs["pre_ids"][i, 0] = self.model_inputs["input_ids"][i, -1]
            self.first_token_record[i : i + 1] = seq_lens[i]

    def run_preprocess(self, share_inputs):
        """
        Update the draft model's parameters from the base model's state.

        Calls the ``draft_model_preprocess`` op with the draft model's inputs and the
        base model's ``accept_tokens``, ``accept_num``, sequence lengths, ``step_idx``,
        ``stop_flags`` and ``draft_tokens``. The final flag is True for "eagle" and
        "mtp" (presumably selecting their hidden-state-driven variant — TODO confirm).
        """
        draft_model_preprocess(
            self.model_inputs["draft_tokens"],
            self.model_inputs["input_ids"],
            self.model_inputs["stop_flags"],
            self.model_inputs["seq_lens_this_time"],
            self.model_inputs["seq_lens_encoder"],
            self.model_inputs["seq_lens_decoder"],
            self.model_inputs["step_idx"],
            self.first_token_record,
            self.model_inputs["not_need_stop"],
            share_inputs["accept_tokens"],
            share_inputs["accept_num"],
            share_inputs["seq_lens_encoder"],
            share_inputs["seq_lens_decoder"],
            share_inputs["step_idx"],
            share_inputs["stop_flags"],
            share_inputs["draft_tokens"],
            self.max_draft_tokens,
            self.draft_type in ["eagle", "mtp"],
        )

    def run_infer(self, share_inputs, **kwargs):
        """
        Run the draft model. Must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Subclasses must implement this function")

    def run_postprocess(self, share_inputs):
        """
        Update the base model's draft_tokens via the ``draft_model_postprocess`` op.
        """
        draft_model_postprocess(
            share_inputs["draft_tokens"],
            share_inputs["seq_lens_this_time"],
            share_inputs["seq_lens_encoder"],
            share_inputs["stop_flags"],
        )

    def postprocess(self, base_model_inputs):
        """
        Reset the per-batch draft-model state.

        Fills ``pre_ids`` with -1 and returns every KV-cache block to the free list.

        Args:
            base_model_inputs (dict): Accepted for interface compatibility; not read.
        """
        # The former scan over base_model_inputs["stop_flags"] only executed `break` and
        # had no effect, so it was removed; the reset below was always unconditional.
        # NOTE(review): confirm whether the reset should be skipped while requests are
        # still running.
        self.pre_ids[:] = -1
        self.free_list = list(range(self.max_block_nums))


# Not verified yet. Reserved API.
class DraftModelProposer(ModelProposer):
    """
    Proposer backed by a standalone draft model.

    Each step calls the draft model repeatedly, for at most ``max_draft_tokens``
    sub-steps, stopping early once ``not_need_stop`` becomes false.
    """

    def insert_query(self, **kwargs):
        """
        Register a new batch, then add one to the encoder-side lengths.

        Increments ``seq_lens_encoder`` and ``seq_lens_this_time`` for every row, and
        ``first_token_record`` for the first ``real_bs`` rows.
        """
        super().insert_query(**kwargs)
        inputs = self.model_inputs
        inputs["seq_lens_encoder"] += 1
        inputs["seq_lens_this_time"] += 1
        for row in range(kwargs.get("real_bs")):
            self.first_token_record[row : row + 1] += 1

    def run_infer(self, share_inputs, **kwargs):
        """
        Call the draft model for up to ``max_draft_tokens`` sub-steps without gradients.
        """
        inputs = self.model_inputs
        if not inputs["not_need_stop"]:
            return
        with paddle.no_grad():
            inputs["substep"] = 0
            while inputs["substep"] < self.max_draft_tokens and inputs["not_need_stop"]:
                self.model(**inputs)
                inputs["substep"] += 1


class EagleProposer(ModelProposer):
    """
    Proposer for EAGLE speculative decoding.

    The draft model is fed hidden states. Before the first sub-step they are
    gathered from the base model (``eagle_get_base_model_hidden_states``); after
    each sub-step they are taken from the draft model's own output
    (``eagle_get_self_hidden_states``).
    """

    def insert_query(self, **kwargs):
        """
        Register a new batch and shift the draft model's input ids left by one.
        """
        super().insert_query(**kwargs)
        base_inputs = kwargs.get("base_model_inputs")

        # Column t of the draft input now holds base column t + 1; the last column keeps
        # the value cloned by the parent's insert_query.
        shifted = base_inputs["input_ids"][:, 1:]
        self.model_inputs["input_ids"][:, :-1] = shifted
        # Scratch buffer: a snapshot of seq_lens_this_time taken before each sub-step.
        self.last_seq_lens_this_time = paddle.full_like(
            base_inputs["seq_lens_this_time"], fill_value=-1, dtype="int32"
        )

    def run_infer(self, share_inputs, **kwargs):
        """
        Generate draft tokens with the EAGLE draft model.

        Args:
            share_inputs (dict): The base model's shared input tensors.
            **kwargs: ``base_model_full_hidden_states`` — presumably the hidden states
                of the base model's latest forward pass (TODO confirm).
        """
        inputs = self.model_inputs
        full_hidden = kwargs.get("base_model_full_hidden_states", None)
        if inputs["not_need_stop"]:
            inputs["hidden_states"] = eagle_get_base_model_hidden_states(
                full_hidden,
                inputs["seq_lens_this_time"],
                inputs["seq_lens_encoder"],
                inputs["seq_lens_decoder"],
                inputs["stop_flags"],
                share_inputs["accept_num"],
                share_inputs["seq_lens_this_time"],
                share_inputs["seq_lens_encoder"],
                self.actual_draft_token_num,
            )

        with paddle.no_grad():
            inputs["substep"] = 0
            while inputs["not_need_stop"] and inputs["substep"] < self.max_draft_tokens:
                # Snapshot the lengths before generate(), which presumably updates them.
                self.last_seq_lens_this_time[:] = inputs["seq_lens_this_time"][:]
                output_hidden_states = self.model.generate(**inputs)
                inputs["substep"] += 1
                if not inputs["not_need_stop"] or inputs["substep"] >= self.actual_draft_token_num:
                    # No further sub-step will consume hidden states.
                    inputs["hidden_states"] = None
                else:
                    inputs["hidden_states"] = eagle_get_self_hidden_states(
                        output_hidden_states,
                        self.last_seq_lens_this_time,
                        inputs["seq_lens_this_time"],
                        inputs["step_idx"],
                    )
